{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Tutorial 11: DeepWalk and node2vec - Implementation details\n",
    "  \n",
    "\n",
    "Papers:\n",
    "* [DeepWalk: Online Learning of Social Representations](https://arxiv.org/pdf/1403.6652.pdf)  \n",
    "* [node2vec: Scalable Feature Learning for Networks](https://arxiv.org/pdf/1607.00653.pdf)  \n",
    "\n",
    "Code:\n",
    "\n",
    " * [node2vec doc](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html?highlight=node2vec#torch_geometric.nn.models.Node2Vec)\n",
    " * [node2vec code](https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/nn/models/node2vec.html)\n",
    " * [Example on clustering](https://github.com/rusty1s/pytorch_geometric/blob/master/examples/node2vec.py)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-1.10.0+cu113.html\n",
    "!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-1.10.0+cu113.html\n",
    "!pip install -q git+https://github.com/pyg-team/pytorch_geometric.git"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch_geometric.nn import Node2Vec\n",
    "import os.path as osp\n",
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.manifold import TSNE\n",
    "from torch_geometric.datasets import Planetoid\n",
    "from tqdm.notebook import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = 'Cora'\n",
    "path = osp.join('.', 'data', dataset)\n",
    "dataset = Planetoid(path, dataset)\n",
    "data = dataset[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "model = Node2Vec(data.edge_index, embedding_dim=128, \n",
    "                 walk_length=20,                        # length of each random walk\n",
    "                 context_size=10, walks_per_node=20,\n",
    "                 num_negative_samples=1, \n",
    "                 p=200, q=1,                             # bias parameters: return (p) and in-out (q)\n",
    "                 sparse=True).to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Random walks"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### The data loader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "loader = model.loader(batch_size=128, shuffle=True, num_workers=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for idx, (pos_rw, neg_rw) in enumerate(loader):\n",
    "    print(idx, pos_rw.shape, neg_rw.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "idx, (pos_rw, neg_rw) = next(enumerate(loader))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "idx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "(pos_rw.shape, neg_rw.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pos_rw"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "neg_rw"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Visualization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import networkx as nx \n",
    "edge_tuples = [tuple(x) for x in data.edge_index.numpy().transpose()]\n",
    "G = nx.from_edgelist(edge_tuples)\n",
    "pos = nx.spring_layout(G, center=[0.5, 0.5])\n",
    "nx.set_node_attributes(G, pos, 'pos')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "nodelist = next(enumerate(loader))[1][0][0].tolist()\n",
    "walk = nx.path_graph(len(nodelist))\n",
    "nx.set_node_attributes(walk, {idx: pos[node_id] for idx, node_id in enumerate(nodelist)}, 'pos')\n",
    "\n",
    "fig = plt.figure(figsize=(20, 10))\n",
    "ax = fig.add_subplot(1, 2, 1)\n",
    "nx.draw_networkx_nodes(G, \n",
    "                       ax=ax,\n",
    "                       pos=nx.get_node_attributes(G, 'pos'), \n",
    "                       node_size=5,\n",
    "                       alpha=0.3,\n",
    "                       node_color='b')\n",
    "nx.draw(walk, \n",
    "        node_size=40,\n",
    "        node_color='r',\n",
    "        ax=ax,\n",
    "        pos=nx.get_node_attributes(walk, 'pos'), \n",
    "        width=2,\n",
    "        edge_color='r') \n",
    "ax = fig.add_subplot(1, 2, 2)\n",
    "nx.draw(walk, \n",
    "        node_size=40,\n",
    "        node_color='r',\n",
    "        ax=ax,\n",
    "        pos=nx.get_node_attributes(walk, 'pos'), \n",
    "        width=2,\n",
    "        edge_color='r') "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Model definition"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = Node2Vec(data.edge_index, embedding_dim=128, walk_length=20,\n",
    "                 context_size=10, walks_per_node=10,\n",
    "                 num_negative_samples=1, p=1, q=1, sparse=True).to(device)\n",
    "\n",
    "loader = model.loader(batch_size=128, shuffle=True, num_workers=4)\n",
    "optimizer = torch.optim.SparseAdam(list(model.parameters()), lr=0.01)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Training function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train():\n",
    "    model.train()\n",
    "    total_loss = 0\n",
    "    for pos_rw, neg_rw in tqdm(loader):\n",
    "        optimizer.zero_grad()\n",
    "        loss = model.loss(pos_rw.to(device), neg_rw.to(device))\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        total_loss += loss.item()\n",
    "    return total_loss / len(loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Test function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data.y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data.y.unique()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@torch.no_grad()\n",
    "def test():\n",
    "    model.eval()\n",
    "    z = model()\n",
    "    acc = model.test(z[data.train_mask], data.y[data.train_mask],\n",
    "                     z[data.test_mask], data.y[data.test_mask],\n",
    "                     max_iter=150)\n",
    "    return acc"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for epoch in range(1, 201):\n",
    "    loss = train()\n",
    "    acc = test()\n",
    "    print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Acc: {acc:.4f}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Visualization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@torch.no_grad()\n",
    "def plot_points(colors):\n",
    "    model.eval()\n",
    "    z = model(torch.arange(data.num_nodes, device=device))\n",
    "    z = TSNE(n_components=2).fit_transform(z.cpu().numpy())\n",
    "    y = data.y.cpu().numpy()\n",
    "\n",
    "    plt.figure(figsize=(8, 8))\n",
    "    for i in range(dataset.num_classes):\n",
    "        plt.scatter(z[y == i, 0], z[y == i, 1], s=20, color=colors[i])\n",
    "    plt.axis('off')\n",
    "    plt.show()\n",
    "\n",
    "colors = [\n",
    "    '#ffc0cb', '#bada55', '#008080', '#420420', '#7fe5f0', '#065535',\n",
    "    '#ffd700'\n",
    "]\n",
    "plot_points(colors)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
